from functools import partial
from itertools import repeat
from typing import List, Optional
import numpy as np

from Utils import logger
from Utils.Constants import FileNamesConstants, Diff as DiffConst
from Utils.utils import pkl_file_reader_gen
from FeatureMap.FeatureMapCreator import FeatureMapGenerator
from DataHandling.FileBasedDatasetBase import FileBasedDataset


class AutoEncoderSingleMapsDataset(FileBasedDataset):
    def __init__(self, files_paths: List[str], is_weights: Optional[bool], pre_load_data: bool):
        """
        Dataset for Auto Encoder from saved feature maps. NaN values are replaced with zeros and infinite values are
        replaced with a large finite value of the same sign.
        :param files_paths: Paths to the saved feature-map files.
        :param is_weights: True keeps only paths containing 'weights', False keeps only paths containing
            'gradients', None keeps all given paths.
        :param pre_load_data: If True, all maps are loaded into memory when the dataset is constructed.
        """
        if is_weights is True:
            files_paths = [fn for fn in files_paths if 'weights' in fn]
        elif is_weights is False:
            files_paths = [fn for fn in files_paths if 'gradients' in fn]

        super().__init__(files_paths, list(), '')
        self._is_weights = is_weights
        self._pre_load_data = pre_load_data
        self._nan_policy = self.__assign_zero_nans
        self._inf_policy = self.__fix_inf_values
        self._maps = self._load_all_maps() if self._pre_load_data else None

    @staticmethod
    def __assign_zero_nans(feature_map):
        """Replace NaN entries of `feature_map` with 0 in place and return the same array."""
        feature_map[np.isnan(feature_map)] = 0
        return feature_map

    @staticmethod
    def __fix_inf_values(feature_map, big_val=1e7):
        """
        Replace infinite entries of `feature_map` in place and return the same array.
        +inf becomes `big_val` and -inf becomes `-big_val`.
        """
        inf_mask = np.isinf(feature_map)
        if inf_mask.any():
            # Keep the sign: -inf must map to a large negative value, not to +big_val
            feature_map[inf_mask] = np.where(feature_map[inf_mask] > 0, big_val, -big_val)
        return feature_map

    def __pre_saved_feature_map_loader(self, feature_map_path) -> np.ndarray:
        """
        Load every object yielded by `pkl_file_reader_gen` for `feature_map_path` into one array.
        The last axis is moved to axis 1, and then the NaN and inf policies are applied.
        """
        curr_data = np.array(list(pkl_file_reader_gen(feature_map_path)))
        # A file holding a single pickled object is unwrapped so the object itself is the map
        if curr_data.shape[0] == 1:
            curr_data = curr_data[0]
        # NOTE(review): assumes the map has at least 2 dims with channels last -- confirm against the writer
        channels_first_map = np.moveaxis(curr_data, -1, 1)
        final_map = self._nan_policy(channels_first_map)
        return self._inf_policy(final_map)

    def _load_all_maps(self):
        """
        Load every file and return a flat list of its entries along the first axis.
        Each file's array is iterated along axis 0 by `list.extend`.
        """
        all_maps = list()
        for curr_file in self._files_paths:
            all_maps.extend(self.__pre_saved_feature_map_loader(curr_file))
        return all_maps

    def __getitem__(self, index):
        """
        Return (map, map), since the auto-encoder target equals its input.
        NOTE(review): with pre_load_data, `index` selects one first-axis entry of the concatenated maps; without it,
        `index` selects a whole file. Confirm that the base class __len__ matches the mode in use.
        """
        if self._pre_load_data:
            curr_map = self._maps[index]
        else:
            curr_map = self.__pre_saved_feature_map_loader(self._files_paths[index])
        return curr_map, curr_map

    @staticmethod
    def create_dataset(files_paths: List[str] = None, is_weights: Optional[bool] = None, pre_load_data: bool = None,
                       weights_maps_base_folder: str = None, gradients_maps_base_folder: str = None, concat_data: bool = None):
        """
        Factory for AutoEncoderSingleMapsDataset. `weights_maps_base_folder`, `gradients_maps_base_folder` and
        `concat_data` are accepted for interface compatibility and are ignored.
        """
        return AutoEncoderSingleMapsDataset(files_paths=files_paths, is_weights=is_weights, pre_load_data=pre_load_data)


class AutoEncoderStatsDataset(FileBasedDataset):
    def __init__(self, files_paths: List[str], is_weights: bool, is_single_layer: bool, layers_first: bool):
        """
        Dataset for Auto Encoder. This dataset doesn't work as a normal dataset. Instead of returning the line from
        the expected feature map, it loads a single feature map, exhausts it, and then loads another one for more
        steps. Once a feature map is loaded, all subsequently requested items are the next steps of that map,
        regardless of the requested index. After a feature map is exhausted, the next one used is the first unused
        file (in files order) whose index is >= the file index derived from the requested index. If none exists,
        the closest unused file with a smaller index is used.
        The dataset goes over all relevant files in files_paths before any file is reused.
        :param files_paths: Paths to all files constructing the dataset (can include non-relevant files)
        :param is_weights: If truthy keep only weights-stats files, otherwise keep only gradients-stats files.
        :param is_single_layer: Forwarded to FeatureMapGenerator.create.
        :param layers_first: Forwarded to FeatureMapGenerator.create as `layers_first`.
        """
        stats_key = FileNamesConstants.WEIGHTS_STATS if is_weights else FileNamesConstants.GRADIENTS_STATS
        files_paths = [fn for fn in files_paths if stats_key in fn]
        super().__init__(files_paths, list(), '')
        self._is_weights = is_weights
        self._is_single_layer = is_single_layer
        self._layers_first = layers_first
        self._feature_map_files = files_paths
        # Presumably each file holds DiffConst.NUMBER_STEPS_SAVED steps -- nominal size only
        self._dataset_size = len(files_paths) * DiffConst.NUMBER_STEPS_SAVED
        # 1 marks a file already opened during the current pass over the files
        self._feature_maps_used = np.zeros((len(files_paths), ))
        # Empty iterator so the first __getitem__ immediately opens a real generator
        self._curr_gen = repeat(np.zeros(1), times=0)
        self._curr_file = None

    def _open_new_gen(self, index):
        """
        Open a generator on the next file. No file is reused until every file has been used, after which the
        usage flags are reset.
        The selected file is the first unused file with index >= the file index derived from `index`. If none
        exists, the unused file with the largest index below it is taken.
        :param index: object index requested by the user (not file index)
        :raises RuntimeError: if the dataset has no stats files.
        """
        if len(self._feature_map_files) == 0:
            raise RuntimeError('AutoEncoderStatsDataset::_open_new_gen: no stats files to read from')
        if np.all(self._feature_maps_used):
            self._feature_maps_used[:] = 0
        file_index = index // DiffConst.NUMBER_STEPS_SAVED
        # Indices of unused files, in ascending order
        available_idx = np.flatnonzero(self._feature_maps_used == 0)
        not_smaller = available_idx[available_idx >= file_index]
        # Closest file at/after file_index, otherwise the closest one before it (the largest available index)
        idx_to_use = not_smaller[0] if len(not_smaller) > 0 else available_idx[-1]
        self._curr_file = self._feature_map_files[idx_to_use]
        logger().log('AutoEncoderStatsDataset::_open_new_gen', 'Opening new generator on ', self._curr_file)
        self._curr_gen = FeatureMapGenerator.create(self._is_weights, self._is_single_layer, self._curr_file,
                                                    layers_first=self._layers_first)
        self._feature_maps_used[idx_to_use] = 1

    def __getitem__(self, index):
        """
        Return the next NaN-free line of the current generator as (map, map), with the last axis moved to axis 0.
        Exhausted generators are replaced via `_open_new_gen`. A file producing a NaN line is abandoned.
        :raises RuntimeError: if no file can produce a NaN-free line.
        """
        # Within 2 * len(files) opens every file is tried at least once. Unused files are consumed first (at most
        # len(files) opens); after the flags reset, the next len(files) opens visit every file.
        max_opens = 2 * len(self._feature_map_files)
        opens = 0
        while True:
            try:
                map_line = next(self._curr_gen)
            except StopIteration:
                logger().log('AutoEncoderStatsDataset::__getitem__', 'Finished iteration on ', self._curr_file)
                if opens >= max_opens > 0:
                    raise RuntimeError('AutoEncoderStatsDataset::__getitem__: no file produced a valid line')
                self._open_new_gen(index)
                opens += 1
                # Retry through the loop, so an empty new generator is handled instead of leaking StopIteration
                continue
            if not np.any(np.isnan(map_line)):
                break
            logger().force_log_and_print('AutoEncoderStatsDataset::__getitem__',
                                         f'Curr file with nan values: {self._curr_file}')
            # Drop the rest of this file; the next loop iteration opens a new generator
            self._curr_gen = repeat(np.zeros(1), times=0)

        channels_first_map = np.moveaxis(map_line, -1, 0)
        return channels_first_map, channels_first_map

    @staticmethod
    def create_dataset(files_paths: List[str], is_weights: bool, is_single_layer: bool, layers_first: bool = True):
        """
        Create a stats dataset for a FeatureMapAutoEncoder
        :param files_paths: Paths to candidate files (filtered by stats kind inside the dataset)
        :param is_weights: Select weights-stats (truthy) or gradients-stats files
        :param is_single_layer: Forwarded to FeatureMapGenerator.create
        :param layers_first: Forwarded to FeatureMapGenerator.create
        :return: a new AutoEncoderStatsDataset
        """
        return AutoEncoderStatsDataset(files_paths, is_weights, is_single_layer, layers_first)
